DeepEM: A Deep Neural Network for DEM Inversion

by Paul Wright$^{1}$, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey

$^{1}$ University of Glasgow; email: paul@pauljwright.co.uk

The intensity observed through optically-thin SDO/AIA filters (94 Å, 131 Å, 171 Å, 193 Å, 211 Å, 335 Å) can be related to the temperature distribution of the solar corona (the differential emission measure; DEM) as

\begin{equation} g_{i} = \int_{T} K_{i}(T) \xi(T) dT \, . \end{equation}

In this equation, $g_{i}$ is the DN s$^{-1}$ px$^{-1}$ value in the $i$th SDO/AIA channel, $K_{i}(T)$ is the corresponding temperature response function, and the DEM, $\xi(T)$, is in units of cm$^{-5}$ K$^{-1}$. The matrix formulation of this integral equation can be represented in the form $\vec{g} = {\bf K}\vec{\xi}$; however, this is an ill-posed inverse problem, and any attempt to directly recover $\vec{\xi}$ leads to significant noise amplification.
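The ill-conditioning can be illustrated with a toy forward model. The sketch below uses hypothetical broad, overlapping Gaussian responses (not the real SDO/AIA response functions) to show that the rows of ${\bf K}$ are nearly degenerate, so a direct pseudo-inverse recovery amplifies even small photometric noise:

```python
import numpy as np

# Hypothetical, idealised kernel: six broad, overlapping Gaussian temperature
# responses (NOT the real SDO/AIA response functions) over 18 log10(T) bins.
rng = np.random.default_rng(0)
logT = np.linspace(5.5, 7.2, 18)
centers = np.linspace(5.8, 7.0, 6)
K = np.exp(-((logT[None, :] - centers[:, None]) / 0.6) ** 2)   # shape (6, 18)

xi = np.exp(-((logT - 6.3) / 0.2) ** 2)   # a smooth "true" DEM (arbitrary units)
g = K @ xi                                # forward model: g = K xi

# The rows of K are nearly degenerate, so the problem is badly conditioned
print(np.linalg.cond(K))   # large condition number

# Direct recovery via the pseudo-inverse amplifies even 0.1% noise
g_noisy = g * (1 + 1e-3 * rng.standard_normal(6))
xi_rec = np.linalg.pinv(K) @ g_noisy
rel_err = np.linalg.norm(xi_rec - xi) / np.linalg.norm(xi)
print(rel_err)   # typically far larger than the 0.1% input perturbation
```

This is why regularised or sparsity-constrained methods are needed in practice.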

There are numerous methods to tackle mathematical problems of this kind, and an increasing number of methods in the literature for recovering the differential emission measure, including techniques based on Tikhonov Regularisation (Hannah & Kontar 2012) and on the concept of sparsity (Cheung et al 2015). In the following notebook, we will demonstrate how a simple 1x1 2D convolutional neural network allows for a significant improvement in computational speed for DEM inversion, with similar fidelity to the method used for training (Basis Pursuit). Additionally, this method, DeepEM, provides solutions with emission measure > 0 in every temperature bin.

DeepEM: A Deep Learning Approach for DEM Inversion

Paul J. Wright, Mark Cheung, Rajat Thomas, Richard Galvez, Alexandre Szenicer, Meng Jin, Andres Munoz-Jaramillo, and David Fouhey


In this chapter we will introduce a Deep Learning approach for DEM Inversion. For this notebook, DeepEM is trained on one set of SDO/AIA observations (six optically thin channels; $6 \times N \times N$) and DEM solutions (in 18 temperature bins from log$_{10}$T = 5.5 - 7.2, $18 \times N \times N$; Cheung et al 2015) at a resolution of $512 \times 512$ ($N = 512$), using a $1 \times 1$ 2D Convolutional Neural Network with a single hidden layer.

The DeepEM method presented here takes every DEM solution with no regard to the quality or existence of the solution. As will be demonstrated, when trained with a single set of images and DEM solutions, DeepEM produces solutions with a similar fidelity to Sparse Inversion (with a significantly increased computation speed); additionally, it finds positive solutions at every pixel, with reduced noise in the DEM solutions.

In [1]:
import os
import json
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from scipy.io import readsav
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable
from torch.utils.data import DataLoader

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [2]:
def em_scale(y):
    # Scale the emission measure to order-unity values for training
    return np.sqrt(y/1e25)

def em_unscale(y):
    # Invert em_scale
    return 1e25*(y*y)

def img_scale(x):
    # Clip negative intensities to zero before taking the square root
    x2 = np.copy(x)  # copy to avoid modifying the input array in place
    x2[x2 <= 0.0] = 0.0
    return np.sqrt(x2)

def img_unscale(x):
    # Invert img_scale (exact for non-negative inputs)
    return x*x
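As a quick check, the emission-measure scalings above are exact inverses of each other for non-negative inputs (a minimal sketch repeating the definitions):

```python
import numpy as np

def em_scale(y):
    # Scale the emission measure to order-unity values for training
    return np.sqrt(y / 1e25)

def em_unscale(y):
    # Invert em_scale
    return 1e25 * (y * y)

# Round trip: em_unscale(em_scale(y)) recovers y for non-negative values
y = np.array([0.0, 1e24, 1e25, 5e26])
assert np.allclose(em_unscale(em_scale(y)), y)
```

The image scaling is only invertible where the input is non-negative, since negative intensities are clipped to zero before the square root.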

Step 1: Obtain Data and Sparse Inversion Solutions for Training

We first load the SDO/AIA images and Basis Pursuit DEM maps.

N.B. While this simplified version of DeepEM has been trained on DEM maps from Basis Pursuit (Cheung et al. 2015), we actively encourage the readers to try their favourite method for DEM inversion!

In [3]:
aia_files = ['AIA_DEM_2011-01-27','AIA_DEM_2011-02-22','AIA_DEM_2011-03-20']
em_cube_files = aia_files

for k, (afile, emfile) in enumerate(zip(aia_files, em_cube_files)):
    afile_name = os.path.join('./DeepEM_Data/', afile + '.aia.npy')
    emfile_name = os.path.join('./DeepEM_Data/', emfile + '.emcube.npy')
    if k == 0:
        # Load the first pair only to obtain the array dimensions,
        # then allocate the full (n_files, channels/bins, N, N) arrays.
        X = np.load(afile_name)   # (6, 512, 512) SDO/AIA channels
        y = np.load(emfile_name)  # (18, 512, 512) DEM temperature bins

        X = np.zeros((len(aia_files), X.shape[0], X.shape[1], X.shape[2]))
        y = np.zeros((len(em_cube_files), y.shape[0], y.shape[1], y.shape[2]))

        nlgT = y.shape[1]  # number of temperature bins (18)
        lgtaxis = np.arange(nlgT)*0.1 + 5.5  # log10(T) from 5.5 to 7.2

    X[k] = np.load(afile_name)
    y[k] = np.load(emfile_name)
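As a sanity check on the temperature grid constructed above (assuming 18 bins, as in Cheung et al. 2015), the axis should run from log$_{10}$T = 5.5 to 7.2 in steps of 0.1:

```python
import numpy as np

# The 18 temperature bins span log10(T) = 5.5 to 7.2 in steps of 0.1
lgtaxis = np.arange(18) * 0.1 + 5.5
assert lgtaxis.size == 18
assert np.isclose(lgtaxis[0], 5.5) and np.isclose(lgtaxis[-1], 7.2)
```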

Step 2: Define the Model

We define the model as a 2D Convolutional Neural Network (CNN) with a kernel size of $1 \times 1$. The model accepts a data cube of $6 \times N \times N$ (SDO/AIA data) and returns a data cube of $18 \times N \times N$ (DEM); when trained, it transforms the input at each pixel (the six SDO/AIA channels; $6 \times 1 \times 1$) to the output (the DEM at that pixel; $18 \times 1 \times 1$).

In [4]:
model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1)).cuda() #Loading model on to gpu
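Because every kernel is $1 \times 1$, each convolutional layer applies the same fully-connected map independently at every pixel; the network is effectively a per-pixel multilayer perceptron from 6 AIA intensities to 18 DEM bins, shared across all pixels. A minimal NumPy sketch of a single such layer (hypothetical random weights, small $N$, no PyTorch required):

```python
import numpy as np

rng = np.random.default_rng(0)
C_in, C_out, N = 6, 18, 4
W = rng.standard_normal((C_out, C_in))   # a 1x1 kernel is just an (18, 6) matrix
b = rng.standard_normal(C_out)
x = rng.standard_normal((C_in, N, N))    # a hypothetical (6, N, N) input cube

# 1x1 "convolution": a matrix multiply over the channel axis at every pixel
y = np.tensordot(W, x, axes=([1], [0])) + b[:, None, None]

# Identical to applying the same linear map to each pixel's 6-vector
y_loop = np.empty((C_out, N, N))
for i in range(N):
    for j in range(N):
        y_loop[:, i, j] = W @ x[:, i, j] + b

assert y.shape == (C_out, N, N)
assert np.allclose(y, y_loop)
```

This is why the model can be trained on $512 \times 512$ images yet applied to inputs of any spatial size.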

Step 3: Train the Model

For training our CNN we select one SDO/AIA data cube ($6\times512\times512$) and the corresponding Sparse Inversion DEM output ($18\times512\times512$). In the case presented here, we train the CNN on an image of the Sun obtained on 27-01-2011, validate on an image of the Sun obtained one synodic rotation later (+26 days; 22-02-2011), and finally test on an image obtained another 26 days later (20-03-2011).

In [5]:
X = img_scale(X)
y = em_scale(y)

X_train = X[0:1] 
y_train = y[0:1] 

X_val = X[1:2] 
y_val = y[1:2] 

X_test = X[2:3] 
y_test = y[2:3]

Plotting SDO/AIA Observations ${\it vs.}$ Basis Pursuit DEM bins

For the test data set, Figure 1 shows the SDO/AIA images for 171 Å, 211 Å, and 94 Å (left to right), with the corresponding DEM maps for the temperature bins that are near the peak sensitivity of each of these relatively isothermal channels. Furthermore, it is clear from the DEM maps that a number of pixels are $zero$. These pixels are primarily located off-disk, but there are a number of pixels on-disk that show this behaviour.

In [6]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(y_test[0,8,:,:],vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(y_test[0,4,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(y_test[0,15,:,:],vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 1: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below. In the DEM bins (bottom) it is clear that some pixels have solutions of DEM = 0, seen as dark regions/clusters of pixels on and off disk.


To implement training and testing of our model, we first define a DEMdata class, and then define functions for training and validation/testing: train_model and valtest_model.

N.B. It is not necessary to train the model; if required, the trained model can be loaded onto the CPU as follows:

model = nn.Sequential(
    nn.Conv2d(6, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 300, kernel_size=1),
    nn.LeakyReLU(),
    nn.Conv2d(300, 18, kernel_size=1))

dem_model_file = 'DeepEM_CNN_HelioML.pth'
model.load_state_dict(torch.load(dem_model_file, map_location='cpu'))

Once you have loaded the model, skip to Step 4: Testing the Model.

In [7]:
class DEMdata(torch.utils.data.Dataset):  # a Dataset (not an nn.Module), as expected by DataLoader
    def __init__(self, xtrain, ytrain, xtest, ytest, xval, yval, split='train'):
        
        if split == 'train':
            self.x = xtrain
            self.y = ytrain
        if split == 'val':
            self.x = xval
            self.y = yval
        if split == 'test':
            self.x = xtest
            self.y = ytest
            
    def __getitem__(self, index):
        return torch.from_numpy(self.x[index]).type(torch.FloatTensor), torch.from_numpy(self.y[index]).type(torch.FloatTensor)

    def __len__(self):
        return self.x.shape[0]
In [8]:
def train_model(dem_loader, criterion, optimizer, epochs=500):
    model.train()
    train_loss_all_batches = []
    train_loss_epoch = []
    train_val = []
    for k in range(epochs):
        count_ = 0
        avg_loss = 0
        # =================== progress indicator ==============
        if k % ((epochs + 1) // 4) == 0:
            print('[{0}]: {1:.1f}% complete: '.format(k, k / epochs * 100))
        # =====================================================
        for img, dem in dem_loader:
            count_ += 1
            optimizer.zero_grad()
            # =================== forward =====================
            img = img.cuda()
            dem = dem.cuda()

            output = model(img) 
            loss = criterion(output, dem)

            loss.backward()
            optimizer.step()
            
            train_loss_all_batches.append(loss.item())
            avg_loss += loss.item()
        # =================== Validation ===================
        dem_data_val = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='val')
        dem_loader_val = DataLoader(dem_data_val, batch_size=1)
        val_loss, dummy, dem_pred_val, dem_in_test_val = valtest_model(dem_loader_val, criterion)
        
        train_loss_epoch.append(avg_loss/count_)
        train_val.append(val_loss)
        
        if k>0:
            # N.B. this reports the previous epoch's validation loss
            print('Epoch: ', k, 'trn_loss: ', avg_loss/count_, 'val_loss: ', train_val[k-1])
        else:
            print('Epoch: ', k, 'trn_loss: ', avg_loss/count_)
            
    torch.save(model.state_dict(), 'DeepEM_CNN_HelioML.pth')
    return train_loss_epoch, train_val

def valtest_model(dem_loader, criterion):

    model.eval()
    
    val_loss = 0
    count = 0
    test_loss = []
    for img, dem in dem_loader:
        count += 1
        # =================== forward =====================
        img = img.cuda()
        dem = dem.cuda()
        
        output = model(img)
        loss = criterion(output, dem)
        test_loss.append(loss.item())
        val_loss += loss.item()
        
    return val_loss/count, test_loss, output, dem

We choose the Adam optimiser with a learning rate of 1e-4 and a weight decay of 1e-9, and use the Mean Squared Error (MSE) between the Sparse Inversion DEM map and the DeepEM map as our loss function.

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-9)
criterion = nn.MSELoss().cuda()

Using the defined functions, dem_data will return the training data, which will be loaded by the DataLoader with batch_size=1 (one 512 x 512 image per batch). For each epoch, train_loss and valdn_loss will be returned by train_model.

In [10]:
dem_data = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='train')
dem_loader = DataLoader(dem_data, batch_size=1)

t0=time.time() #Timing how long it takes to predict the DEMs
train_loss, valdn_loss = train_model(dem_loader, criterion, optimizer, epochs=500)
ttime = "Training time = {0} seconds".format(time.time()-t0)
print(ttime)
[0]: 0.0% complete: 
Epoch:  0 trn_loss:  2.578221321105957
Epoch:  1 trn_loss:  2.3965187072753906 val_loss:  2.7426037788391113
Epoch:  2 trn_loss:  2.22440242767334 val_loss:  2.552943229675293
Epoch:  3 trn_loss:  2.0621798038482666 val_loss:  2.373962879180908
Epoch:  4 trn_loss:  1.9103635549545288 val_loss:  2.206118583679199
Epoch:  5 trn_loss:  1.768867015838623 val_loss:  2.049288749694824
Epoch:  6 trn_loss:  1.6374139785766602 val_loss:  1.9032350778579712
Epoch:  7 trn_loss:  1.5156819820404053 val_loss:  1.7676445245742798
Epoch:  8 trn_loss:  1.4032021760940552 val_loss:  1.641958236694336
Epoch:  9 trn_loss:  1.2995561361312866 val_loss:  1.5258108377456665
Epoch:  10 trn_loss:  1.2043824195861816 val_loss:  1.4188122749328613
Epoch:  11 trn_loss:  1.117397665977478 val_loss:  1.3206236362457275
Epoch:  12 trn_loss:  1.03847074508667 val_loss:  1.2311606407165527
Epoch:  13 trn_loss:  0.9674689173698425 val_loss:  1.1503758430480957
Epoch:  14 trn_loss:  0.9039780497550964 val_loss:  1.0779398679733276
Epoch:  15 trn_loss:  0.847331166267395 val_loss:  1.013088583946228
Epoch:  16 trn_loss:  0.7970718741416931 val_loss:  0.9553214311599731
Epoch:  17 trn_loss:  0.7528130412101746 val_loss:  0.9042891263961792
Epoch:  18 trn_loss:  0.7141948342323303 val_loss:  0.8596652150154114
Epoch:  19 trn_loss:  0.6808304786682129 val_loss:  0.8210365772247314
Epoch:  20 trn_loss:  0.652292013168335 val_loss:  0.7879370450973511
Epoch:  21 trn_loss:  0.6281659007072449 val_loss:  0.7599106431007385
Epoch:  22 trn_loss:  0.6080874800682068 val_loss:  0.7365540266036987
Epoch:  23 trn_loss:  0.5917630791664124 val_loss:  0.7175019979476929
Epoch:  24 trn_loss:  0.5788245797157288 val_loss:  0.702392041683197
Epoch:  25 trn_loss:  0.568841278553009 val_loss:  0.6907913684844971
Epoch:  26 trn_loss:  0.5614078044891357 val_loss:  0.6822508573532104
Epoch:  27 trn_loss:  0.5561255216598511 val_loss:  0.6763443946838379
Epoch:  28 trn_loss:  0.5525712966918945 val_loss:  0.6725645065307617
Epoch:  29 trn_loss:  0.5502450466156006 val_loss:  0.6703435778617859
Epoch:  30 trn_loss:  0.5486972332000732 val_loss:  0.6691796779632568
Epoch:  31 trn_loss:  0.5475630760192871 val_loss:  0.6685906648635864
Epoch:  32 trn_loss:  0.5464308857917786 val_loss:  0.6680597066879272
Epoch:  33 trn_loss:  0.5450358986854553 val_loss:  0.6672754287719727
Epoch:  34 trn_loss:  0.5432081818580627 val_loss:  0.6660330891609192
Epoch:  35 trn_loss:  0.5408215522766113 val_loss:  0.6641598343849182
Epoch:  36 trn_loss:  0.5377958416938782 val_loss:  0.6615045070648193
Epoch:  37 trn_loss:  0.5340939164161682 val_loss:  0.6579532623291016
Epoch:  38 trn_loss:  0.5297320485115051 val_loss:  0.6535050272941589
Epoch:  39 trn_loss:  0.5247765183448792 val_loss:  0.6482319235801697
Epoch:  40 trn_loss:  0.5193665027618408 val_loss:  0.6422970294952393
Epoch:  41 trn_loss:  0.5136898756027222 val_loss:  0.6359198689460754
Epoch:  42 trn_loss:  0.5079445242881775 val_loss:  0.629324197769165
Epoch:  43 trn_loss:  0.5022999048233032 val_loss:  0.6226987838745117
Epoch:  44 trn_loss:  0.496860146522522 val_loss:  0.61616450548172
Epoch:  45 trn_loss:  0.49168145656585693 val_loss:  0.6098062992095947
Epoch:  46 trn_loss:  0.4868048429489136 val_loss:  0.6036902070045471
Epoch:  47 trn_loss:  0.4822656810283661 val_loss:  0.5978918075561523
Epoch:  48 trn_loss:  0.47810056805610657 val_loss:  0.5924769639968872
Epoch:  49 trn_loss:  0.47430527210235596 val_loss:  0.5874605178833008
Epoch:  50 trn_loss:  0.4708606004714966 val_loss:  0.5828387141227722
Epoch:  51 trn_loss:  0.4677314758300781 val_loss:  0.5786116123199463
Epoch:  52 trn_loss:  0.46488624811172485 val_loss:  0.5747863054275513
Epoch:  53 trn_loss:  0.46227461099624634 val_loss:  0.5713388919830322
Epoch:  54 trn_loss:  0.4598667323589325 val_loss:  0.5682305097579956
Epoch:  55 trn_loss:  0.45762667059898376 val_loss:  0.5654037594795227
Epoch:  56 trn_loss:  0.4555044174194336 val_loss:  0.5628034472465515
Epoch:  57 trn_loss:  0.45346057415008545 val_loss:  0.5603811740875244
Epoch:  58 trn_loss:  0.45146486163139343 val_loss:  0.5580896735191345
Epoch:  59 trn_loss:  0.44949114322662354 val_loss:  0.5558891296386719
Epoch:  60 trn_loss:  0.4475208818912506 val_loss:  0.553735613822937
Epoch:  61 trn_loss:  0.44554826617240906 val_loss:  0.5515900254249573
Epoch:  62 trn_loss:  0.4435963034629822 val_loss:  0.5494500994682312
Epoch:  63 trn_loss:  0.44168785214424133 val_loss:  0.547325611114502
Epoch:  64 trn_loss:  0.4398011565208435 val_loss:  0.545186460018158
Epoch:  65 trn_loss:  0.4379242956638336 val_loss:  0.543016791343689
Epoch:  66 trn_loss:  0.436063289642334 val_loss:  0.5408163666725159
Epoch:  67 trn_loss:  0.43423229455947876 val_loss:  0.5385938286781311
Epoch:  68 trn_loss:  0.4324449300765991 val_loss:  0.5363556146621704
Epoch:  69 trn_loss:  0.4307076036930084 val_loss:  0.5341012477874756
Epoch:  70 trn_loss:  0.4290217161178589 val_loss:  0.5318295955657959
Epoch:  71 trn_loss:  0.42738544940948486 val_loss:  0.5295465588569641
Epoch:  72 trn_loss:  0.42579734325408936 val_loss:  0.5272670388221741
Epoch:  73 trn_loss:  0.4242590665817261 val_loss:  0.5250152945518494
Epoch:  74 trn_loss:  0.42277437448501587 val_loss:  0.5228168368339539
Epoch:  75 trn_loss:  0.4213421642780304 val_loss:  0.5206938982009888
Epoch:  76 trn_loss:  0.4199572503566742 val_loss:  0.518661379814148
Epoch:  77 trn_loss:  0.4186100959777832 val_loss:  0.5167261958122253
Epoch:  78 trn_loss:  0.4172911047935486 val_loss:  0.5148859024047852
Epoch:  79 trn_loss:  0.41599199175834656 val_loss:  0.5131323933601379
Epoch:  80 trn_loss:  0.41470766067504883 val_loss:  0.5114537477493286
Epoch:  81 trn_loss:  0.41343551874160767 val_loss:  0.5098355412483215
Epoch:  82 trn_loss:  0.41217443346977234 val_loss:  0.5082621574401855
Epoch:  83 trn_loss:  0.41092362999916077 val_loss:  0.5067175626754761
Epoch:  84 trn_loss:  0.409682959318161 val_loss:  0.5051876306533813
Epoch:  85 trn_loss:  0.4084519147872925 val_loss:  0.5036616921424866
Epoch:  86 trn_loss:  0.4072303771972656 val_loss:  0.5021343231201172
Epoch:  87 trn_loss:  0.4060187041759491 val_loss:  0.5006040334701538
Epoch:  88 trn_loss:  0.4048176109790802 val_loss:  0.4990733563899994
Epoch:  89 trn_loss:  0.40362730622291565 val_loss:  0.4975457787513733
Epoch:  90 trn_loss:  0.4024472236633301 val_loss:  0.49602624773979187
Epoch:  91 trn_loss:  0.4012766480445862 val_loss:  0.49451908469200134
Epoch:  92 trn_loss:  0.40011507272720337 val_loss:  0.49302783608436584
Epoch:  93 trn_loss:  0.3989623785018921 val_loss:  0.4915548861026764
Epoch:  94 trn_loss:  0.39781877398490906 val_loss:  0.49010053277015686
Epoch:  95 trn_loss:  0.3966839611530304 val_loss:  0.48866400122642517
Epoch:  96 trn_loss:  0.39555734395980835 val_loss:  0.48724302649497986
Epoch:  97 trn_loss:  0.3944377899169922 val_loss:  0.4858343303203583
Epoch:  98 trn_loss:  0.3933241367340088 val_loss:  0.4844350814819336
Epoch:  99 trn_loss:  0.3922153115272522 val_loss:  0.48304325342178345
Epoch:  100 trn_loss:  0.39111071825027466 val_loss:  0.4816577732563019
Epoch:  101 trn_loss:  0.3900098204612732 val_loss:  0.4802793264389038
Epoch:  102 trn_loss:  0.3889121413230896 val_loss:  0.4789085388183594
Epoch:  103 trn_loss:  0.3878171145915985 val_loss:  0.4775473177433014
Epoch:  104 trn_loss:  0.38672447204589844 val_loss:  0.4761970341205597
Epoch:  105 trn_loss:  0.38563400506973267 val_loss:  0.47485923767089844
Epoch:  106 trn_loss:  0.3845451772212982 val_loss:  0.4735337495803833
Epoch:  107 trn_loss:  0.38345760107040405 val_loss:  0.47221967577934265
Epoch:  108 trn_loss:  0.3823707401752472 val_loss:  0.4709148406982422
Epoch:  109 trn_loss:  0.3812824487686157 val_loss:  0.4696134924888611
Epoch:  110 trn_loss:  0.38018423318862915 val_loss:  0.46830251812934875
Epoch:  111 trn_loss:  0.37906646728515625 val_loss:  0.4669584631919861
Epoch:  112 trn_loss:  0.3779512345790863 val_loss:  0.4655682146549225
Epoch:  113 trn_loss:  0.37689208984375 val_loss:  0.46413475275039673
Epoch:  114 trn_loss:  0.3758038878440857 val_loss:  0.46280914545059204
Epoch:  115 trn_loss:  0.3746999502182007 val_loss:  0.46161186695098877
Epoch:  116 trn_loss:  0.37359046936035156 val_loss:  0.46050527691841125
Epoch:  117 trn_loss:  0.3724815845489502 val_loss:  0.4594343602657318
Epoch:  118 trn_loss:  0.37137436866760254 val_loss:  0.4583365023136139
Epoch:  119 trn_loss:  0.3702658414840698 val_loss:  0.4571554362773895
Epoch:  120 trn_loss:  0.36915305256843567 val_loss:  0.45587125420570374
Epoch:  121 trn_loss:  0.36803728342056274 val_loss:  0.454497754573822
Epoch:  122 trn_loss:  0.36691898107528687 val_loss:  0.4530557096004486
Epoch:  123 trn_loss:  0.36579588055610657 val_loss:  0.45156681537628174
Epoch:  124 trn_loss:  0.3646661937236786 val_loss:  0.45005205273628235
[125]: 25.0% complete: 
Epoch:  125 trn_loss:  0.3635307550430298 val_loss:  0.44853082299232483
Epoch:  126 trn_loss:  0.3623957931995392 val_loss:  0.44702550768852234
Epoch:  127 trn_loss:  0.3612777590751648 val_loss:  0.44555535912513733
Epoch:  128 trn_loss:  0.3601588010787964 val_loss:  0.44415509700775146
Epoch:  129 trn_loss:  0.3590296804904938 val_loss:  0.4427774250507355
Epoch:  130 trn_loss:  0.3578943908214569 val_loss:  0.44144874811172485
Epoch:  131 trn_loss:  0.3567531704902649 val_loss:  0.44016531109809875
Epoch:  132 trn_loss:  0.3556116819381714 val_loss:  0.43891891837120056
Epoch:  133 trn_loss:  0.3544723093509674 val_loss:  0.4376954138278961
Epoch:  134 trn_loss:  0.3533313572406769 val_loss:  0.4364749789237976
Epoch:  135 trn_loss:  0.35218486189842224 val_loss:  0.43524327874183655
Epoch:  136 trn_loss:  0.3510321080684662 val_loss:  0.43399176001548767
Epoch:  137 trn_loss:  0.3498746156692505 val_loss:  0.4327225089073181
Epoch:  138 trn_loss:  0.34871402382850647 val_loss:  0.43143653869628906
Epoch:  139 trn_loss:  0.34755077958106995 val_loss:  0.4301411807537079
Epoch:  140 trn_loss:  0.34638601541519165 val_loss:  0.42883220314979553
Epoch:  141 trn_loss:  0.3452184796333313 val_loss:  0.4274977743625641
Epoch:  142 trn_loss:  0.34404677152633667 val_loss:  0.4261408746242523
Epoch:  143 trn_loss:  0.3428720533847809 val_loss:  0.42476749420166016
Epoch:  144 trn_loss:  0.3416943848133087 val_loss:  0.4233819246292114
Epoch:  145 trn_loss:  0.3405115008354187 val_loss:  0.4219890832901001
Epoch:  146 trn_loss:  0.3393220901489258 val_loss:  0.4205976128578186
Epoch:  147 trn_loss:  0.33812639117240906 val_loss:  0.419213205575943
Epoch:  148 trn_loss:  0.3369247019290924 val_loss:  0.4178359806537628
Epoch:  149 trn_loss:  0.3357217013835907 val_loss:  0.4164617359638214
Epoch:  150 trn_loss:  0.3345317840576172 val_loss:  0.41507697105407715
Epoch:  151 trn_loss:  0.33334970474243164 val_loss:  0.41363176703453064
Epoch:  152 trn_loss:  0.33216503262519836 val_loss:  0.4121124744415283
Epoch:  153 trn_loss:  0.3309818208217621 val_loss:  0.410550057888031
Epoch:  154 trn_loss:  0.3298012316226959 val_loss:  0.4089862108230591
Epoch:  155 trn_loss:  0.3286168575286865 val_loss:  0.40745192766189575
Epoch:  156 trn_loss:  0.32742610573768616 val_loss:  0.40597015619277954
Epoch:  157 trn_loss:  0.32623380422592163 val_loss:  0.4045504927635193
Epoch:  158 trn_loss:  0.3250555694103241 val_loss:  0.40319791436195374
Epoch:  159 trn_loss:  0.3238782584667206 val_loss:  0.40185144543647766
Epoch:  160 trn_loss:  0.3226983845233917 val_loss:  0.4004645049571991
Epoch:  161 trn_loss:  0.3215292692184448 val_loss:  0.3990371525287628
Epoch:  162 trn_loss:  0.32036474347114563 val_loss:  0.39756420254707336
Epoch:  163 trn_loss:  0.3191930651664734 val_loss:  0.39604923129081726
Epoch:  164 trn_loss:  0.31802111864089966 val_loss:  0.3945232927799225
Epoch:  165 trn_loss:  0.3168593645095825 val_loss:  0.39301925897598267
Epoch:  166 trn_loss:  0.31570154428482056 val_loss:  0.39154523611068726
Epoch:  167 trn_loss:  0.3145410418510437 val_loss:  0.39010822772979736
Epoch:  168 trn_loss:  0.31338343024253845 val_loss:  0.3887196183204651
Epoch:  169 trn_loss:  0.31223204731941223 val_loss:  0.38737714290618896
Epoch:  170 trn_loss:  0.3110830783843994 val_loss:  0.3860596716403961
Epoch:  171 trn_loss:  0.3099355697631836 val_loss:  0.3847445845603943
Epoch:  172 trn_loss:  0.3087936341762543 val_loss:  0.38341742753982544
Epoch:  173 trn_loss:  0.30765849351882935 val_loss:  0.38206425309181213
Epoch:  174 trn_loss:  0.3065260350704193 val_loss:  0.38067328929901123
Epoch:  175 trn_loss:  0.30539628863334656 val_loss:  0.37925249338150024
Epoch:  176 trn_loss:  0.30427277088165283 val_loss:  0.37781980633735657
Epoch:  177 trn_loss:  0.3031555712223053 val_loss:  0.3763926327228546
Epoch:  178 trn_loss:  0.30204248428344727 val_loss:  0.37498268485069275
Epoch:  179 trn_loss:  0.30093351006507874 val_loss:  0.37359747290611267
Epoch:  180 trn_loss:  0.2998305857181549 val_loss:  0.37223851680755615
Epoch:  181 trn_loss:  0.29873356223106384 val_loss:  0.3708992302417755
Epoch:  182 trn_loss:  0.29764142632484436 val_loss:  0.36956894397735596
Epoch:  183 trn_loss:  0.29655537009239197 val_loss:  0.3682395815849304
Epoch:  184 trn_loss:  0.2954767644405365 val_loss:  0.36690670251846313
Epoch:  185 trn_loss:  0.2944047749042511 val_loss:  0.36556732654571533
Epoch:  186 trn_loss:  0.29333898425102234 val_loss:  0.36422160267829895
Epoch:  187 trn_loss:  0.2922804355621338 val_loss:  0.3628738820552826
Epoch:  188 trn_loss:  0.2912302017211914 val_loss:  0.36152979731559753
Epoch:  189 trn_loss:  0.29018768668174744 val_loss:  0.3601939380168915
Epoch:  190 trn_loss:  0.2891528308391571 val_loss:  0.3588702976703644
Epoch:  191 trn_loss:  0.2881257236003876 val_loss:  0.357561320066452
Epoch:  192 trn_loss:  0.2871066629886627 val_loss:  0.356267511844635
Epoch:  193 trn_loss:  0.2860959470272064 val_loss:  0.3549865186214447
Epoch:  194 trn_loss:  0.2850929796695709 val_loss:  0.3537105917930603
Epoch:  195 trn_loss:  0.2840968370437622 val_loss:  0.3524378836154938
Epoch:  196 trn_loss:  0.28310611844062805 val_loss:  0.3511672914028168
Epoch:  197 trn_loss:  0.28211814165115356 val_loss:  0.34989631175994873
Epoch:  198 trn_loss:  0.2811289131641388 val_loss:  0.348615437746048
Epoch:  199 trn_loss:  0.2801346778869629 val_loss:  0.34731459617614746
Epoch:  200 trn_loss:  0.2791338562965393 val_loss:  0.3459857404232025
Epoch:  201 trn_loss:  0.2781164348125458 val_loss:  0.3446342945098877
Epoch:  202 trn_loss:  0.2770686447620392 val_loss:  0.3432539701461792
Epoch:  203 trn_loss:  0.2759862542152405 val_loss:  0.34182971715927124
Epoch:  204 trn_loss:  0.27488598227500916 val_loss:  0.34034764766693115
Epoch:  205 trn_loss:  0.2738039195537567 val_loss:  0.3388190269470215
Epoch:  206 trn_loss:  0.2727872133255005 val_loss:  0.3373345136642456
Epoch:  207 trn_loss:  0.27185651659965515 val_loss:  0.3359776437282562
Epoch:  208 trn_loss:  0.2709851861000061 val_loss:  0.3347732424736023
Epoch:  209 trn_loss:  0.27012521028518677 val_loss:  0.33369576930999756
Epoch:  210 trn_loss:  0.2692703902721405 val_loss:  0.33269867300987244
Epoch:  211 trn_loss:  0.2684323787689209 val_loss:  0.33172979950904846
Epoch:  212 trn_loss:  0.2676165997982025 val_loss:  0.33073389530181885
Epoch:  213 trn_loss:  0.26681196689605713 val_loss:  0.3296637237071991
Epoch:  214 trn_loss:  0.26600706577301025 val_loss:  0.3285127878189087
Epoch:  215 trn_loss:  0.26519882678985596 val_loss:  0.32731980085372925
Epoch:  216 trn_loss:  0.2643932104110718 val_loss:  0.326142817735672
Epoch:  217 trn_loss:  0.2635973393917084 val_loss:  0.3250265121459961
Epoch:  218 trn_loss:  0.2628103494644165 val_loss:  0.3239758610725403
Epoch:  219 trn_loss:  0.2620280086994171 val_loss:  0.3229658305644989
Epoch:  220 trn_loss:  0.26123562455177307 val_loss:  0.32195737957954407
Epoch:  221 trn_loss:  0.260422021150589 val_loss:  0.32090112566947937
Epoch:  222 trn_loss:  0.2595825493335724 val_loss:  0.31976011395454407
Epoch:  223 trn_loss:  0.2587239444255829 val_loss:  0.3185500204563141
Epoch:  224 trn_loss:  0.25782859325408936 val_loss:  0.31729018688201904
Epoch:  225 trn_loss:  0.2569177448749542 val_loss:  0.3160383701324463
Epoch:  226 trn_loss:  0.2560807764530182 val_loss:  0.31491443514823914
Epoch:  227 trn_loss:  0.2553540766239166 val_loss:  0.31398826837539673
Epoch:  228 trn_loss:  0.25469961762428284 val_loss:  0.31317102909088135
Epoch:  229 trn_loss:  0.25404104590415955 val_loss:  0.31229326128959656
Epoch:  230 trn_loss:  0.25338056683540344 val_loss:  0.31130754947662354
Epoch:  231 trn_loss:  0.2527371048927307 val_loss:  0.3102620542049408
Epoch:  232 trn_loss:  0.2521231770515442 val_loss:  0.30923235416412354
Epoch:  233 trn_loss:  0.2515329122543335 val_loss:  0.30827850103378296
Epoch:  234 trn_loss:  0.25095123052597046 val_loss:  0.30743053555488586
Epoch:  235 trn_loss:  0.2503703534603119 val_loss:  0.30669260025024414
Epoch:  236 trn_loss:  0.24979250133037567 val_loss:  0.3060412108898163
Epoch:  237 trn_loss:  0.2492249310016632 val_loss:  0.30543121695518494
Epoch:  238 trn_loss:  0.24867184460163116 val_loss:  0.3048091232776642
Epoch:  239 trn_loss:  0.2481302171945572 val_loss:  0.3041329085826874
Epoch:  240 trn_loss:  0.24759486317634583 val_loss:  0.3033941984176636
Epoch:  241 trn_loss:  0.24706517159938812 val_loss:  0.30262088775634766
Epoch:  242 trn_loss:  0.24654310941696167 val_loss:  0.3018553853034973
Epoch:  243 trn_loss:  0.24602797627449036 val_loss:  0.3011277914047241
Epoch:  244 trn_loss:  0.24551725387573242 val_loss:  0.30044272541999817
Epoch:  245 trn_loss:  0.2450115978717804 val_loss:  0.29978200793266296
Epoch:  246 trn_loss:  0.24451200664043427 val_loss:  0.2991197407245636
Epoch:  247 trn_loss:  0.24401675164699554 val_loss:  0.2984353303909302
[... output truncated: both losses decrease steadily over epochs 248-498 ...]
Epoch:  499 trn_loss:  0.16463641822338104 val_loss:  0.21018409729003906
Training time = 174.25515794754028 seconds

Plotting: MSE Loss for Training and Validation

To assess how well the model has trained, we plot the training and validation loss as a function of epoch. Figure 2 shows the MSE loss for training (blue) and validation (orange).

In [11]:
plt.plot(np.arange(len(train_loss)), train_loss, color="blue", label="Training")
plt.plot(np.arange(len(valdn_loss)), valdn_loss, color="orange", label="Validation")
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.show()

Figure 2: Training and Validation MSE loss (blue, orange) as a function of Epoch.
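In this run the validation loss is still decreasing at epoch 499, but in general it can begin to rise while the training loss keeps falling. A common safeguard is early stopping on the recorded validation losses; below is a minimal sketch using a hypothetical `best_epoch` helper (not part of the DeepEM code):

```python
def best_epoch(valdn_loss, patience=10):
    """Return the index of the lowest validation loss, stopping the
    search once `patience` epochs pass without improvement."""
    best, best_i = float("inf"), 0
    for i, v in enumerate(valdn_loss):
        if v < best:
            best, best_i = v, i
        elif i - best_i >= patience:
            break  # no improvement for `patience` epochs
    return best_i

# e.g. a validation loss that improves and then degrades
print(best_epoch([0.5, 0.4, 0.3, 0.31, 0.32, 0.33], patience=2))  # 2
```

Applied to the `valdn_loss` array recorded above, this would simply return the final epoch, confirming that training had not yet started to overfit.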


Step 4: Testing the Model

Now that the model has been trained, testing it is a computationally cheap procedure. As before, we choose the data using DEMdata, and load it with DataLoader. Using valtest_model, the DeepEM map is created with ${\texttt{output = model(img)}}$, and the MSE loss is calculated as during training.

In [12]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

t0=time.time() #Timing how long it takes to predict the DEMs
dummy, test_loss, dem_pred, dem_in_test = valtest_model(dem_loader, criterion)
performance = "Number of DEM solutions per second = {0}".format((y_test.shape[2]*y_test.shape[3])/(time.time()-t0))

print(performance)
Number of DEM solutions per second = 2683542.159541545

Plotting: AIA, Basis Pursuit, DeepEM

With the DeepEM map calculated, we can now compare the solutions obtained by Basis Pursuit and DeepEM. Figure 3 is similar to Figure 1 with an additional row corresponding to the DeepEM solutions. Figure 3 shows SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each SDO/AIA channel) shown below (middle). The bottom row shows the DeepEM solutions for the same bins. DeepEM provides solutions that are similar to Basis Pursuit, but importantly, provides DEM solutions for every pixel.

In [13]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)

Figure 3: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top), with the corresponding DEM bins from Basis Pursuit (chosen at the peak sensitivity of each of the SDO/AIA channels) shown below (middle). The bottom row shows the DeepEM solutions that correspond to the same bins as the Basis Pursuit solutions. DeepEM provides solutions that are similar to Basis Pursuit, but importantly, provides DEM solutions for every pixel.


Furthermore, as we have the original Basis Pursuit DEM solutions ("the ground truth"), we can compare the average DEM from Basis Pursuit to the average DEM from DeepEM, as they should be similar. Figure 4 shows the average Basis Pursuit DEM (black curve) and the DeepEM solution (light blue bars/dotted line).

In [14]:
def PlotTotalEM(em_unscaled, em_pred_unscaled, lgtaxis, status):
    mask = np.zeros([status.shape[0],status.shape[1]])
    mask[np.where(status == 0.0)] = 1.0
    nmask = np.sum(mask)
    
    EM_tru_sum = np.zeros([lgtaxis.size])
    EM_inv_sum = np.zeros([lgtaxis.size])
    
    for i in range(lgtaxis.size):
        EM_tru_sum[i] = np.sum(em_unscaled[0,i,:,:]*mask)/nmask
        EM_inv_sum[i] = np.sum(em_pred_unscaled[0,i,:,:]*mask)/nmask
        
    fig = plt.figure()
    plt.plot(lgtaxis,EM_tru_sum, linewidth=3, color="black")
    plt.plot(lgtaxis,EM_inv_sum, linewidth=3, color="lightblue", linestyle='--')
    plt.tick_params(axis='both', which='major')#, labelsize=16)
    plt.tick_params(axis='both', which='minor')#, labelsize=16)
    
    dlogT = lgtaxis[1]-lgtaxis[0]
    plt.bar(lgtaxis-0.5*dlogT, EM_inv_sum, dlogT, linewidth=2, color='lightblue')
    
    plt.xlim(lgtaxis[0]-0.5*dlogT, lgtaxis.max()+0.5*dlogT)
    plt.xticks(np.arange(np.min(lgtaxis), np.max(lgtaxis),2*dlogT))
    plt.ylim(1e24,1e27)
    plt.yscale('log')
    plt.xlabel('log$_{10}$T [K]')
    plt.ylabel('Mean Emission Measure [cm$^{-5}$]')
    plt.title('Basis Pursuit (curve) vs. DeepEM (bars)')
    
    plt.show()
    return EM_inv_sum, EM_tru_sum
In [15]:
em_unscaled = em_unscale(dem_in_test.detach().cpu().numpy())
em_pred_unscaled = em_unscale(dem_pred.detach().cpu().numpy())
status = np.zeros([512,512])
                   
EMinv, EMTru = PlotTotalEM(em_unscaled,em_pred_unscaled,lgtaxis,status)

Figure 4: Average Basis Pursuit DEM (black line) against the Average DeepEM solution (light blue bars/dotted line). It is clear that this simple implementation of DeepEM provides, on average, DEMs that are similar to Basis Pursuit (Cheung et al 2015).
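The agreement in Figure 4 can also be quantified. As a minimal sketch (a hypothetical metric, not part of the notebook), the mean absolute log-ratio between the two mean-EM curves returned by PlotTotalEM (EMTru and EMinv) is zero for identical curves and grows with disagreement:

```python
import numpy as np

def mean_abs_log_ratio(em_true, em_pred, eps=1e-30):
    """Mean |log10(pred / true)| over temperature bins; 0 means identical."""
    em_true = np.asarray(em_true, dtype=float)
    em_pred = np.asarray(em_pred, dtype=float)
    # eps guards against empty (zero-emission) temperature bins
    return float(np.mean(np.abs(np.log10((em_pred + eps) / (em_true + eps)))))
```

For example, `mean_abs_log_ratio(EMTru, EMinv)` gives a single-number summary of Figure 4 in dex.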


Step 5: Synthesize SDO/AIA Observations

Finally, it is also of interest to reconstruct the SDO/AIA observations from both the Basis Pursuit and DeepEM solutions.

We can pose the problem of reconstructing the SDO/AIA observations from the DEM as a 1x1 2D convolution. We first define the weights as the response functions of each channel, and set the biases to zero. By convolving the unscaled DEM at each pixel with the 6 filters (one for each SDO/AIA response function), we recover the SDO/AIA observations.
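The reason a $1 \times 1$ convolution suffices is that the forward model $\vec{g} = {\bf K}\vec{\xi}$ acts independently at each pixel. A minimal numpy sketch of the single-pixel case, with random stand-ins for the response functions and DEM:

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.random((6, 18))   # stand-in for the 6 AIA responses over 18 T bins
xi = rng.random(18)       # stand-in for one pixel's (unscaled) DEM
g = K @ xi                # synthesized intensities in the six channels

# A 1x1 Conv2d(18, 6) with weights K and zero bias applies this same
# matrix multiplication at every pixel of an (18, N, N) DEM cube.
```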

In [16]:
# We first load the AIA response functions:
cl = np.load('./DeepEM_Data/chianti_lines_AIA.npy')
In [17]:
# Use Conv2d to convolve every pixel (18x1x1) with the 6 response functions
# to return a set of observed fluxes in each channel (6x1x1)
dem2aia = nn.Conv2d(18, 6, kernel_size=1).cuda()

chianti_lines_2 = torch.zeros(6,18,1,1).cuda()
biases = torch.zeros(6).cuda()

# set the weights to each of the SDO/AIA response functions and biases to zero
for i, p in enumerate(dem2aia.parameters()):
    if i == 0:
        p.data = Variable(torch.from_numpy(cl).type(torch.cuda.FloatTensor))
    else:
        p.data = biases 
In [18]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred))).detach().cpu().numpy())

Plotting SDO/AIA Observations and Synthetic Observations

In [19]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    

Figure 5: Left to Right: SDO/AIA images in 171 Å, 211 Å, and 94 Å (top) with the corresponding synthesised observations from Basis Pursuit (middle) and DeepEM (bottom). DeepEM provides synthetic observations that are similar to Basis Pursuit, with the addition of solutions where the Basis Pursuit solution was zero.


Discussion

This chapter has provided an example of how a 1x1 2D Convolutional Neural Network can be used to reduce the computational cost of DEM inversion. Future improvements to DeepEM could come in a few ways:

First, by using both the original and synthesised data, the ability of the DEM to recover the original or supplementary observations can be used as an additional term in the loss function. Furthermore, a number of additional data sources could be used to further constrain the DEMs:

  • Use SDO/AIA Data to correct the DEMs
  • Use MEGS-A EUV to correct the DEMs
  • Use Hard X-ray observations to correct the DEMs
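As a concrete illustration of the first idea, the synthesis term can be added to the training loss. The sketch below is hypothetical (not part of the DeepEM code): `dem2aia` stands in for the fixed 1x1 convolution defined above, and `lam` is a free weighting hyperparameter.

```python
import torch
import torch.nn as nn

def combined_loss(dem_pred, dem_true, aia_obs, dem2aia, lam=0.1):
    """MSE on the DEM plus a penalty on the disagreement between the
    synthesized and observed AIA images, weighted by `lam`."""
    mse = nn.functional.mse_loss
    return mse(dem_pred, dem_true) + lam * mse(dem2aia(dem_pred), aia_obs)

# tiny CPU example on random tensors
dem2aia_cpu = nn.Conv2d(18, 6, kernel_size=1, bias=False)
dem = torch.rand(1, 18, 8, 8)
aia = torch.rand(1, 6, 8, 8)
loss = combined_loss(dem, dem.clone(), aia, dem2aia_cpu)
```

During training, the synthesis term would be computed on the unscaled DEM, as in Step 5, so that the recovered intensities are in the same units as the observations.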

Appendix A: What has the CNN learned about our training set?

By treating the training set as the test set, we can see how much the CNN has learned about the training data:

In [21]:
X_test = X_train 
y_test = y_train
In [22]:
dem_data_test = DEMdata(X_train, y_train, X_test, y_test, X_val, y_val, split='test')
dem_loader = DataLoader(dem_data_test, batch_size=1)

dummy, test_loss, dem_pred_trn, dem_in_test_trn = valtest_model(dem_loader, criterion)
In [23]:
AIA_out = img_scale(dem2aia(Variable(em_unscale(dem_in_test_trn))).detach().cpu().numpy())
AIA_out_DeepEM = img_scale(dem2aia(Variable(em_unscale(dem_pred_trn))).detach().cpu().numpy())
In [25]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_in_test_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_in_test_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_in_test_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'Basis Pursuit DEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(dem_pred_trn[0,8,:,:].cpu().detach().numpy(),vmin=0.25,vmax=10,cmap='viridis')
ax[1].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 6.3', color="white", size='large')
#ax[1].scatter(x=[400], y=[80], c='w', s=50, marker='x') 
ax[0].imshow(dem_pred_trn[0,4,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[0].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 5.9', color="white", size='large')
ax[2].imshow(dem_pred_trn[0,15,:,:].cpu().detach().numpy(),vmin=0.01,vmax=3,cmap='viridis')
ax[2].text(5, 512.-7.5, 'DeepEM log$_{10}$T ~ 7.0', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
In [24]:
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(X_test[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, '${\it SDO}$/AIA 211 $\AA$', color="white", size='large')
ax[0].imshow(X_test[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, '${\it SDO}$/AIA 171 $\AA$', color="white", size='large')
ax[2].imshow(X_test[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, '${\it SDO}$/AIA 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'Basis Pursuit Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'Basis Pursuit Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'Basis Pursuit Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    
fig,ax=plt.subplots(ncols=3,figsize=(9*2,9))

ax[1].imshow(AIA_out_DeepEM[0,4,:,:],vmin=0.25,vmax=25,cmap='Greys_r')
ax[1].text(5, 512.-7.5, 'DeepEM Synthesized 211 $\AA$', color="white", size='large')
ax[0].imshow(AIA_out_DeepEM[0,2,:,:],vmin=0.01,vmax=30,cmap='Greys_r')
ax[0].text(5, 512.-7.5, 'DeepEM Synthesized 171 $\AA$', color="white", size='large')
ax[2].imshow(AIA_out_DeepEM[0,0,:,:],vmin=0.01,vmax=3,cmap='Greys_r')
ax[2].text(5, 512.-7.5, 'DeepEM Synthesized 94 $\AA$', color="white", size='large')

for axes in ax:
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)